{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "# Fine-tuning DTrOCR on IAM dataset\n",
    "This notebook is an example of fine-tuning DTrOCR on handwritten word images from the IAM dataset, which is available on [Kaggle](https://www.kaggle.com/datasets/teykaicong/iamondb-handwriting-dataset). The IAM Aachen splits can be downloaded from [OpenSLR](https://www.openslr.org/56/)."
   ],
   "id": "4cda29000b17cc0a"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "# Dataset folder structure\n",
    "```\n",
    "iam_words/\n",
    "│\n",
    "├── words/                              # Folder containing word images as PNGs\n",
    "│   ├── a01/                            # First folder\n",
    "│   │   ├── a01-000u/\n",
    "│   │   │   ├── a01-000u-00-00.png\n",
    "│   │   │   └── a01-000u-00-01.png\n",
    "│   │   .\n",
    "│   │   .\n",
    "│   │   .\n",
    "│   └── r06/                            # Last folder\n",
    "│       ├── r06-000/\n",
    "│       │   ├── r06-000-00-00.png\n",
    "│       │   └── r06-000-00-01.png\n",
    "│\n",
    "├── xml/                                # XML files\n",
    "│   ├── a01-000u.xml\n",
    "│   .\n",
    "│   .\n",
    "│   .\n",
    "│   └── r06-143.xml\n",
    "│\n",
    "└── splits/                             # IAM Aachen splits\n",
    "    ├── train.uttlist\n",
    "    ├── validation.uttlist\n",
    "    └── test.uttlist\n",
    "```"
   ],
   "id": "672df8f4f58440b7"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Build lists of images and texts",
   "id": "89f83c2b6af325eb"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "import glob\n",
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "# Root of the IAM words dataset. Set the IAM_WORDS_DIR environment variable to use\n",
    "# a different location without editing this cell.\n",
    "dataset_path = Path(os.environ.get('IAM_WORDS_DIR', '/home/arvind/datasets/iam_words'))\n",
    "\n",
    "# Sorted so that the file order is reproducible between runs.\n",
    "xml_files = sorted(glob.glob(str(dataset_path / 'xml' / '*.xml')))\n",
    "word_image_files = sorted(glob.glob(str(dataset_path / 'words' / '**' / '*.png'), recursive=True))\n",
    "\n",
    "# Fail fast with a clear message rather than carrying an empty dataset into later cells.\n",
    "if not xml_files or not word_image_files:\n",
    "    raise FileNotFoundError(\n",
    "        f'Expected xml/*.xml and words/**/*.png under {dataset_path}; '\n",
    "        f'found {len(xml_files)} XML files and {len(word_image_files)} word image files'\n",
    "    )\n",
    "\n",
    "print(f\"{len(xml_files)} XML files and {len(word_image_files)} word image files\")"
   ],
   "id": "fa6ad879545d49c7",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "import tqdm\n",
    "import multiprocessing as mp\n",
    "import xml.etree.ElementTree as ET\n",
    "\n",
    "from PIL import Image\n",
    "from dataclasses import dataclass\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class Word:\n",
    "    \"\"\"One IAM word image and the metadata read for it from the form XML.\"\"\"\n",
    "    id: str  # 'id' attribute of the XML root (the form), not of the individual word\n",
    "    file_path: Path  # path of the word's PNG image\n",
    "    writer_id: str  # 'writer-id' attribute of the XML root\n",
    "    transcription: str  # 'text' attribute of the word element\n",
    "\n",
    "# Map image file name -> path once, so that each word is resolved with a dictionary\n",
    "# lookup instead of a scan over every image path. setdefault keeps the first path in\n",
    "# sorted order when a file name occurs more than once.\n",
    "word_image_index = {}\n",
    "for word_image_file in word_image_files:\n",
    "    word_image_index.setdefault(Path(word_image_file).name, word_image_file)\n",
    "\n",
    "def get_words_from_xml(xml_file):\n",
    "    \"\"\"Parse one form XML file and return a Word for each word whose image can be opened.\n",
    "\n",
    "    Words that have no image in `word_image_index`, or whose image cannot be opened\n",
    "    by PIL, are skipped instead of aborting the whole run.\n",
    "\n",
    "    Args:\n",
    "        xml_file: path of the XML file to parse.\n",
    "\n",
    "    Returns:\n",
    "        List of Word records in document order; empty when the XML has no\n",
    "        'handwritten-part' element.\n",
    "    \"\"\"\n",
    "    root = ET.parse(xml_file).getroot()\n",
    "    root_id = root.get('id')\n",
    "    writer_id = root.get('writer-id')\n",
    "\n",
    "    handwritten_part = root.find('handwritten-part')\n",
    "    if handwritten_part is None:\n",
    "        return []\n",
    "\n",
    "    xml_words = []\n",
    "    for line in handwritten_part.findall('line'):\n",
    "        for word in line.findall('word'):\n",
    "            image_file = word_image_index.get(f\"{word.get('id')}.png\")\n",
    "            if image_file is None:\n",
    "                # the XML lists a word for which no PNG was found on disk\n",
    "                continue\n",
    "            image_file = Path(image_file)\n",
    "            try:\n",
    "                # opening the file is only a readability check; the image itself is not kept\n",
    "                with Image.open(image_file):\n",
    "                    pass\n",
    "            except Exception:\n",
    "                # deliberate best effort: unreadable images are left out of the dataset\n",
    "                continue\n",
    "            xml_words.append(\n",
    "                Word(\n",
    "                    id=root_id,\n",
    "                    file_path=image_file,\n",
    "                    writer_id=writer_id,\n",
    "                    transcription=word.get('text')\n",
    "                )\n",
    "            )\n",
    "\n",
    "    return xml_words\n",
    "\n",
    "# Workers read the module-level word_image_index, so it is built before the pool starts.\n",
    "with mp.Pool(processes=mp.cpu_count()) as pool:\n",
    "    words_from_xmls = list(\n",
    "        tqdm.tqdm(\n",
    "            pool.imap(get_words_from_xml, xml_files),\n",
    "            total=len(xml_files),\n",
    "            desc='Building dataset'\n",
    "        )\n",
    "    )\n",
    "\n",
    "# flatten the per-file lists into one list of Word records\n",
    "words = [word for form_words in words_from_xmls for word in form_words]"
   ],
   "id": "7633dafd56363d1c",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Train, validation and test split",
   "id": "d0ea1601f0dabbde"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "def read_split_ids(split_file):\n",
    "    \"\"\"Return the ids listed one per line in `split_file`.\n",
    "\n",
    "    Surrounding whitespace (including a '\\r' left by Windows line endings) is\n",
    "    stripped and blank lines are ignored.\n",
    "    \"\"\"\n",
    "    with open(split_file) as fp:\n",
    "        return [line.strip() for line in fp if line.strip()]\n",
    "\n",
    "# Resolve the split lists relative to dataset_path instead of repeating the absolute path.\n",
    "splits_path = dataset_path / 'splits'\n",
    "train_ids = read_split_ids(splits_path / 'train.uttlist')\n",
    "test_ids = read_split_ids(splits_path / 'test.uttlist')\n",
    "validation_ids = read_split_ids(splits_path / 'validation.uttlist')\n",
    "\n",
    "print(f\"Train size: {len(train_ids)}; Validation size: {len(validation_ids)}; Test size: {len(test_ids)}\")"
   ],
   "id": "c5c924f758cc73a4",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Convert each id list to a set once: the membership tests below then cost a hash\n",
    "# lookup per word instead of a scan over the whole split list. Results are unchanged.\n",
    "train_id_set = set(train_ids)\n",
    "validation_id_set = set(validation_ids)\n",
    "test_id_set = set(test_ids)\n",
    "\n",
    "# A word belongs to a split when the id stored on the record is listed in that split.\n",
    "train_word_records = [word for word in words if word.id in train_id_set]\n",
    "validation_word_records = [word for word in words if word.id in validation_id_set]\n",
    "test_word_records = [word for word in words if word.id in test_id_set]\n",
    "\n",
    "print(f'Train size: {len(train_word_records)}; Validation size: {len(validation_word_records)}; Test size: {len(test_word_records)}')"
   ],
   "id": "b5433445168c7f98",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Build dataset and dataloader",
   "id": "ce1d6d01cac2ef35"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "from dtrocr.processor import DTrOCRProcessor\n",
    "from dtrocr.config import DTrOCRConfig\n",
    "\n",
    "from torch.utils.data import Dataset\n",
    "\n",
    "class IAMDataset(Dataset):\n",
    "    \"\"\"Map-style dataset that turns Word records into DTrOCR processor outputs.\"\"\"\n",
    "\n",
    "    def __init__(self, words: list[Word], config: DTrOCRConfig):\n",
    "        \"\"\"Store the records and build the processor used to encode them.\n",
    "\n",
    "        Args:\n",
    "            words: records to serve; `file_path` and `transcription` are read from each.\n",
    "            config: configuration handed to DTrOCRProcessor.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.words = words\n",
    "        self.processor = DTrOCRProcessor(config, add_eos_token=True, add_bos_token=True)\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of word records.\"\"\"\n",
    "        return len(self.words)\n",
    "\n",
    "    def __getitem__(self, item):\n",
    "        \"\"\"Encode the word at index `item`.\n",
    "\n",
    "        Returns:\n",
    "            Dict with keys 'pixel_values', 'input_ids', 'attention_mask' and 'labels',\n",
    "            each holding the first entry of the corresponding processor output.\n",
    "        \"\"\"\n",
    "        word = self.words[item]\n",
    "        # Close the file as soon as the pixels are converted, so DataLoader workers do\n",
    "        # not accumulate open file handles; convert() returns a separate image.\n",
    "        with Image.open(word.file_path) as source_image:\n",
    "            image = source_image.convert('RGB')\n",
    "\n",
    "        inputs = self.processor(\n",
    "            images=image,\n",
    "            texts=word.transcription,\n",
    "            padding='max_length',\n",
    "            return_tensors=\"pt\",\n",
    "            return_labels=True,\n",
    "        )\n",
    "        return {\n",
    "            'pixel_values': inputs.pixel_values[0],\n",
    "            'input_ids': inputs.input_ids[0],\n",
    "            'attention_mask': inputs.attention_mask[0],\n",
    "            'labels': inputs.labels[0]\n",
    "        }\n",
    "\n",
    "# Configuration shared by the datasets; the commented-out argument is an optional\n",
    "# override that is left disabled.\n",
    "config = DTrOCRConfig(\n",
    "    # attn_implementation='flash_attention_2'\n",
    ")\n",
    "\n",
    "# One dataset per split, built in train / validation / test order.\n",
    "train_data, validation_data, test_data = (\n",
    "    IAMDataset(words=split_records, config=config)\n",
    "    for split_records in (train_word_records, validation_word_records, test_word_records)\n",
    ")"
   ],
   "id": "cfc14f60f4a160e6",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# Settings common to all three loaders; only the training loader shuffles.\n",
    "loader_options = {'batch_size': 32, 'num_workers': mp.cpu_count()}\n",
    "\n",
    "train_dataloader = DataLoader(train_data, shuffle=True, **loader_options)\n",
    "validation_dataloader = DataLoader(validation_data, shuffle=False, **loader_options)\n",
    "test_dataloader = DataLoader(test_data, shuffle=False, **loader_options)"
   ],
   "id": "3f9717df54449a10",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Model",
   "id": "c460aa9a2caa3af6"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "import torch\n",
    "# set PyTorch's float32 matmul precision to 'high' before the model is built\n",
    "torch.set_float32_matmul_precision('high')\n",
    "\n",
    "from dtrocr.model import DTrOCRLMHeadModel\n",
    "\n",
    "# build the model from the shared config and compile it, then move it to device 0\n",
    "model = torch.compile(DTrOCRLMHeadModel(config))\n",
    "model.to(device=0)"
   ],
   "id": "18096c11905a980e",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Training",
   "id": "2a10a9feb1801174"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "from typing import Tuple\n",
    "\n",
    "def evaluate_model(model: torch.nn.Module, dataloader: DataLoader, device=0) -> Tuple[float, float]:\n",
    "    \"\"\"Return the mean loss and mean accuracy of `model` over the batches of `dataloader`.\n",
    "\n",
    "    Both values are unweighted averages of the per-batch `outputs.loss` and\n",
    "    `outputs.accuracy`. The model is switched to evaluation mode for the pass and put\n",
    "    back into whichever mode it was in before the call.\n",
    "\n",
    "    Args:\n",
    "        model: module called as `model(**inputs)`; its output must expose `loss` and `accuracy`.\n",
    "        dataloader: yields dictionaries of model inputs.\n",
    "        device: device the batches are moved to before the forward pass (default 0).\n",
    "\n",
    "    Returns:\n",
    "        Tuple of (loss, accuracy).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `dataloader` contains no batches.\n",
    "    \"\"\"\n",
    "    if len(dataloader) == 0:\n",
    "        raise ValueError('Cannot evaluate on an empty dataloader')\n",
    "\n",
    "    # remember the current mode so evaluation does not silently flip it\n",
    "    was_training = model.training\n",
    "    model.eval()\n",
    "\n",
    "    losses, accuracies = [], []\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            for inputs in tqdm.tqdm(dataloader, total=len(dataloader), desc='Evaluating'):\n",
    "                inputs = send_inputs_to_device(inputs, device=device)\n",
    "                outputs = model(**inputs)\n",
    "\n",
    "                losses.append(outputs.loss.item())\n",
    "                accuracies.append(outputs.accuracy.item())\n",
    "    finally:\n",
    "        # restore the previous mode even if a batch fails\n",
    "        model.train(was_training)\n",
    "\n",
    "    loss = sum(losses) / len(losses)\n",
    "    accuracy = sum(accuracies) / len(accuracies)\n",
    "\n",
    "    return loss, accuracy\n",
    "\n",
    "def send_inputs_to_device(dictionary, device):\n",
    "    return {key: value.to(device=device) if isinstance(value, torch.Tensor) else value for key, value in dictionary.items()}\n",
    "\n",
    "# Mixed precision: the forward pass below runs under float16 autocast and the loss\n",
    "# is scaled by `scaler` before the backward pass. Both are switched by `use_amp`.\n",
    "use_amp = True\n",
    "scaler = torch.cuda.amp.GradScaler(enabled=use_amp)\n",
    "optimiser = torch.optim.Adam(params=model.parameters(), lr=1e-4)\n",
    "\n",
    "EPOCHS = 50\n",
    "# history lists: one averaged value is appended per epoch\n",
    "train_losses, train_accuracies = [], []\n",
    "validation_losses, validation_accuracies = [], []\n",
    "for epoch in range(EPOCHS):\n",
    "    # per-batch values of the current epoch, averaged after the inner loop\n",
    "    epoch_losses, epoch_accuracies = [], []\n",
    "    for inputs in tqdm.tqdm(train_dataloader, total=len(train_dataloader), desc=f'Epoch {epoch + 1}'):\n",
    "        \n",
    "        # set gradients to zero\n",
    "        optimiser.zero_grad()\n",
    "        \n",
    "        # send inputs to same device as model\n",
    "        inputs = send_inputs_to_device(inputs, device=0)\n",
    "        \n",
    "        # forward pass (only this step runs inside the autocast context)\n",
    "        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):\n",
    "            outputs = model(**inputs)\n",
    "        \n",
    "        # calculate gradients from the scaled loss\n",
    "        scaler.scale(outputs.loss).backward()\n",
    "        \n",
    "        # update weights through the scaler, then let it update its scale factor\n",
    "        scaler.step(optimiser)\n",
    "        scaler.update()\n",
    "        \n",
    "        epoch_losses.append(outputs.loss.item())\n",
    "        epoch_accuracies.append(outputs.accuracy.item())\n",
    "        \n",
    "    # store the epoch's training loss and accuracy (unweighted mean over batches)\n",
    "    train_losses.append(sum(epoch_losses) / len(epoch_losses))\n",
    "    train_accuracies.append(sum(epoch_accuracies) / len(epoch_accuracies))\n",
    "    \n",
    "    # validation loss and accuracy, measured on the validation dataloader after every epoch\n",
    "    validation_loss, validation_accuracy = evaluate_model(model, validation_dataloader)\n",
    "    validation_losses.append(validation_loss)\n",
    "    validation_accuracies.append(validation_accuracy)\n",
    "                    \n",
    "    print(f\"Epoch: {epoch + 1} - Train loss: {train_losses[-1]}, Train accuracy: {train_accuracies[-1]}, Validation loss: {validation_losses[-1]}, Validation accuracy: {validation_accuracies[-1]}\")"
   ],
   "id": "d8257602fcea271d",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Test",
   "id": "6399ad87bbfd16da"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "from dtrocr.model import DTrOCRLMHeadModel\n",
    "from dtrocr.config import DTrOCRConfig\n",
    "from dtrocr.processor import DTrOCRProcessor\n",
    "\n",
    "# The model from the cells above is reused for inference; the commented-out line is\n",
    "# the alternative of starting from a newly constructed model.\n",
    "# model = DTrOCRLMHeadModel(DTrOCRConfig())\n",
    "\n",
    "# move the model to the CPU and put it in evaluation mode\n",
    "model.to('cpu')\n",
    "model.eval()\n",
    "\n",
    "# processor used to prepare the test images and decode the generated tokens\n",
    "test_processor = DTrOCRProcessor(DTrOCRConfig())"
   ],
   "id": "b51d7fd369dee5ce",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "\n",
    "from PIL import Image\n",
    "\n",
    "# number of test words to transcribe and display\n",
    "NUM_TEST_SAMPLES = 50\n",
    "\n",
    "for test_word_record in test_word_records[:NUM_TEST_SAMPLES]:\n",
    "    image_file = test_word_record.file_path\n",
    "    # convert inside the context manager so the file is closed before the model runs\n",
    "    with Image.open(image_file) as source_image:\n",
    "        image = source_image.convert('RGB')\n",
    "\n",
    "    # the text prompt is only the beginning-of-sequence token; the model generates the rest\n",
    "    inputs = test_processor(\n",
    "        images=image,\n",
    "        texts=test_processor.tokeniser.bos_token,\n",
    "        return_tensors='pt'\n",
    "    )\n",
    "\n",
    "    # gradients are not needed for generation\n",
    "    with torch.no_grad():\n",
    "        model_output = model.generate(\n",
    "            inputs,\n",
    "            test_processor,\n",
    "            num_beams=3\n",
    "        )\n",
    "\n",
    "    predicted_text = test_processor.tokeniser.decode(model_output[0], skip_special_tokens=True)\n",
    "\n",
    "    # show the word image with the predicted transcription as its title\n",
    "    fig, ax = plt.subplots(figsize=(10, 5))\n",
    "    ax.set_title(predicted_text, fontsize=24)\n",
    "    ax.imshow(np.array(image, dtype=np.uint8))\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    plt.show()\n",
    "    # release the figure so figures do not pile up across iterations\n",
    "    plt.close(fig)"
   ],
   "id": "81e815a0557072e4",
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
